{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Expose only GPU 0 to CUDA; this is set before TensorFlow is imported in the next cell.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/lamb_optimizer.py:34: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from albert import modeling\n",
    "from albert import optimization\n",
    "from albert import tokenization\n",
    "import tensorflow as tf\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/tokenization.py:240: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
      "\n",
      "INFO:tensorflow:loading sentence piece model\n"
     ]
    }
   ],
   "source": [
    "# Tokenizer built from the vocab / SentencePiece model files; casing is kept (do_lower_case = False).\n",
    "tokenizer = tokenization.FullTokenizer(\n",
    "    vocab_file = 'albert-base-2020-04-10/sp10m.cased.v10.vocab',\n",
    "    do_lower_case = False,\n",
    "    spm_model_file = 'albert-base-2020-04-10/sp10m.cased.v10.model',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:116: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<albert.modeling.AlbertConfig at 0x7f214e19ea20>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the ALBERT configuration from the JSON file stored in the same directory as the tokenizer files.\n",
    "bert_config = modeling.AlbertConfig.from_json_file('albert-base-2020-04-10/config.json')\n",
    "# Trailing expression: shows the config object as the cell output.\n",
    "bert_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# NOTE(review): pickle.load can execute arbitrary code -- only load a file you produced yourself.\n",
    "# The pickle is unpacked into a (features, examples) pair; presumably written by a\n",
    "# preprocessing notebook -- confirm.\n",
    "with open('albert-squad-test.pkl', 'rb') as fopen:\n",
    "    test_features, test_examples = pickle.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# max_seq_length is used as the one-hot depth and tile length when the QA head is built,\n",
    "# so fed batches presumably have to be padded to exactly this length -- confirm.\n",
    "max_seq_length = 384\n",
    "# NOTE(review): doc_stride and max_query_length are not read anywhere in this notebook;\n",
    "# presumably they record the settings used when the features were generated -- confirm.\n",
    "doc_stride = 128\n",
    "max_query_length = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# batch_size drives the inference loop; the step counts below are derived\n",
    "# from the number of test features.\n",
    "batch_size = 22\n",
    "epoch = 5\n",
    "n_best_size = 20\n",
    "warmup_proportion = 0.1\n",
    "\n",
    "steps_per_epoch = len(test_features) / batch_size\n",
    "num_train_steps = int(steps_per_epoch * epoch)\n",
    "num_warmup_steps = int(warmup_proportion * num_train_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.contrib import layers as contrib_layers\n",
    "\n",
    "class Model:\n",
    "    \"\"\"ALBERT encoder together with the int32 placeholders used to feed it.\n",
    "\n",
    "    Attributes:\n",
    "        X, segment_ids, input_masks, p_mask: rank-2 int32 placeholders.\n",
    "        start_positions, end_positions, is_impossible: rank-1 int32 placeholders.\n",
    "        output: tensor returned by the encoder's get_sequence_output().\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, is_training = True):\n",
    "        \"\"\"Create the placeholders and the encoder; `is_training` is forwarded to AlbertModel.\"\"\"\n",
    "        matrix_shape = [None, None]\n",
    "        vector_shape = [None]\n",
    "\n",
    "        self.X = tf.placeholder(tf.int32, matrix_shape)\n",
    "        self.segment_ids = tf.placeholder(tf.int32, matrix_shape)\n",
    "        self.input_masks = tf.placeholder(tf.int32, matrix_shape)\n",
    "        self.start_positions = tf.placeholder(tf.int32, vector_shape)\n",
    "        self.end_positions = tf.placeholder(tf.int32, vector_shape)\n",
    "        self.p_mask = tf.placeholder(tf.int32, matrix_shape)\n",
    "        self.is_impossible = tf.placeholder(tf.int32, vector_shape)\n",
    "\n",
    "        encoder = modeling.AlbertModel(\n",
    "            config = bert_config,\n",
    "            is_training = is_training,\n",
    "            input_ids = self.X,\n",
    "            input_mask = self.input_masks,\n",
    "            token_type_ids = self.segment_ids,\n",
    "            use_one_hot_embeddings = False,\n",
    "        )\n",
    "        self.output = encoder.get_sequence_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:194: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:507: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:588: The name tf.assert_less_equal is deprecated. Please use tf.compat.v1.assert_less_equal instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:1025: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:253: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:187: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n"
     ]
    }
   ],
   "source": [
    "# Build the graph in inference mode. start_n_top / end_n_top are the top-k sizes\n",
    "# used when the QA head is constructed.\n",
    "is_training = False\n",
    "start_n_top = 5\n",
    "end_n_top = 5\n",
    "learning_rate = 2e-5\n",
    "\n",
    "tf.reset_default_graph()\n",
    "model = Model(is_training = is_training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-10-857e07a2a191>:157: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dropout instead.\n"
     ]
    }
   ],
   "source": [
    "# Swap the first two axes of the encoder output; the einsum patterns below index it as 'lbh'.\n",
    "# (assumes model.output is [batch, seq, hidden] -- TODO confirm)\n",
    "bsz = tf.shape(model.output)[0]\n",
    "return_dict = {}\n",
    "output = tf.transpose(model.output, [1, 0, 2])\n",
    "\n",
    "# invalid position mask such as query and special symbols (PAD, SEP, CLS)\n",
    "p_mask = tf.cast(model.p_mask, dtype = tf.float32)\n",
    "\n",
    "# logit of the start position\n",
    "with tf.variable_scope('start_logits'):\n",
    "    start_projection = tf.layers.dense(\n",
    "        output,\n",
    "        1,\n",
    "        kernel_initializer = modeling.create_initializer(bert_config.initializer_range),\n",
    "    )\n",
    "    # drop the size-1 unit axis and swap the two remaining axes\n",
    "    start_logits = tf.transpose(tf.squeeze(start_projection, -1), [1, 0])\n",
    "    # positions flagged in p_mask receive a very large negative logit\n",
    "    keep_mask = 1 - p_mask\n",
    "    start_logits_masked = start_logits * keep_mask - 1e30 * p_mask\n",
    "    start_log_probs = tf.nn.log_softmax(start_logits_masked, -1)\n",
    "\n",
    "# logit of the end position\n",
    "# `is_training` is a Python bool, so only one of the two branches below is added to the graph.\n",
    "# Both branches create their layers under the same names ('dense_0', 'dense_1') in this scope.\n",
    "# max_seq_length is a Python constant here; presumably fed batches must be padded to it -- confirm.\n",
    "with tf.variable_scope('end_logits'):\n",
    "    if is_training:\n",
    "        # during training, compute the end logits based on the\n",
    "        # ground truth of the start position\n",
    "        start_positions = tf.reshape(model.start_positions, [-1])\n",
    "        start_index = tf.one_hot(\n",
    "            start_positions,\n",
    "            depth = max_seq_length,\n",
    "            axis = -1,\n",
    "            dtype = tf.float32,\n",
    "        )\n",
    "        # contract the sequence axis with the one-hot index: picks the hidden\n",
    "        # vector at each ground-truth start position\n",
    "        start_features = tf.einsum('lbh,bl->bh', output, start_index)\n",
    "        # repeat the start vector along a new leading axis of length max_seq_length\n",
    "        start_features = tf.tile(\n",
    "            start_features[None], [max_seq_length, 1, 1]\n",
    "        )\n",
    "        end_logits = tf.layers.dense(\n",
    "            tf.concat([output, start_features], axis = -1),\n",
    "            bert_config.hidden_size,\n",
    "            kernel_initializer = modeling.create_initializer(\n",
    "                bert_config.initializer_range\n",
    "            ),\n",
    "            activation = tf.tanh,\n",
    "            name = 'dense_0',\n",
    "        )\n",
    "        end_logits = contrib_layers.layer_norm(\n",
    "            end_logits, begin_norm_axis = -1\n",
    "        )\n",
    "\n",
    "        end_logits = tf.layers.dense(\n",
    "            end_logits,\n",
    "            1,\n",
    "            kernel_initializer = modeling.create_initializer(\n",
    "                bert_config.initializer_range\n",
    "            ),\n",
    "            name = 'dense_1',\n",
    "        )\n",
    "        # same masking as for the start logits: p_mask positions get a very large negative logit\n",
    "        end_logits = tf.transpose(tf.squeeze(end_logits, -1), [1, 0])\n",
    "        end_logits_masked = end_logits * (1 - p_mask) - 1e30 * p_mask\n",
    "        end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)\n",
    "    else:\n",
    "        # during inference, compute the end logits based on beam search\n",
    "\n",
    "        # keep the start_n_top highest start log-probabilities and their positions\n",
    "        start_top_log_probs, start_top_index = tf.nn.top_k(\n",
    "            start_log_probs, k = start_n_top\n",
    "        )\n",
    "        start_index = tf.one_hot(\n",
    "            start_top_index,\n",
    "            depth = max_seq_length,\n",
    "            axis = -1,\n",
    "            dtype = tf.float32,\n",
    "        )\n",
    "        # hidden vector at each of the k candidate start positions ('bkh')\n",
    "        start_features = tf.einsum('lbh,bkl->bkh', output, start_index)\n",
    "        # repeat the encoder output across a new axis of size start_n_top, and the start\n",
    "        # vectors across a new leading axis of length max_seq_length, so they can be concatenated\n",
    "        end_input = tf.tile(output[:, :, None], [1, 1, start_n_top, 1])\n",
    "        start_features = tf.tile(\n",
    "            start_features[None], [max_seq_length, 1, 1, 1]\n",
    "        )\n",
    "        end_input = tf.concat([end_input, start_features], axis = -1)\n",
    "        end_logits = tf.layers.dense(\n",
    "            end_input,\n",
    "            bert_config.hidden_size,\n",
    "            kernel_initializer = modeling.create_initializer(\n",
    "                bert_config.initializer_range\n",
    "            ),\n",
    "            activation = tf.tanh,\n",
    "            name = 'dense_0',\n",
    "        )\n",
    "        end_logits = contrib_layers.layer_norm(\n",
    "            end_logits, begin_norm_axis = -1\n",
    "        )\n",
    "        end_logits = tf.layers.dense(\n",
    "            end_logits,\n",
    "            1,\n",
    "            kernel_initializer = modeling.create_initializer(\n",
    "                bert_config.initializer_range\n",
    "            ),\n",
    "            name = 'dense_1',\n",
    "        )\n",
    "        # drop the trailing size-1 unit axis, then move the sequence axis last\n",
    "        end_logits = tf.reshape(\n",
    "            end_logits, [max_seq_length, -1, start_n_top]\n",
    "        )\n",
    "        end_logits = tf.transpose(end_logits, [1, 2, 0])\n",
    "        # p_mask[:, None] adds an axis so the mask broadcasts over the start candidates\n",
    "        end_logits_masked = (\n",
    "            end_logits * (1 - p_mask[:, None]) - 1e30 * p_mask[:, None]\n",
    "        )\n",
    "        end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)\n",
    "        end_top_log_probs, end_top_index = tf.nn.top_k(\n",
    "            end_log_probs, k = end_n_top\n",
    "        )\n",
    "        # flatten the (start candidate, end candidate) pairs into one axis\n",
    "        # of size start_n_top * end_n_top\n",
    "        end_top_log_probs = tf.reshape(\n",
    "            end_top_log_probs, [-1, start_n_top * end_n_top]\n",
    "        )\n",
    "        end_top_index = tf.reshape(\n",
    "            end_top_index, [-1, start_n_top * end_n_top]\n",
    "        )\n",
    "        \n",
    "# publish the tensors that exist for the mode the graph was built in\n",
    "if is_training:\n",
    "    return_dict.update(\n",
    "        start_log_probs = start_log_probs,\n",
    "        end_log_probs = end_log_probs,\n",
    "    )\n",
    "else:\n",
    "    return_dict.update(\n",
    "        start_top_log_probs = start_top_log_probs,\n",
    "        start_top_index = start_top_index,\n",
    "        end_top_log_probs = end_top_log_probs,\n",
    "        end_top_index = end_top_index,\n",
    "    )\n",
    "\n",
    "# an additional layer to predict answerability\n",
    "with tf.variable_scope('answer_class'):\n",
    "    # representation of CLS: a one-hot on position 0 selects the first position's hidden vector\n",
    "    first_position = tf.zeros([bsz], dtype = tf.int32)\n",
    "    cls_index = tf.one_hot(first_position, max_seq_length, axis = -1, dtype = tf.float32)\n",
    "    cls_feature = tf.einsum('lbh,bl->bh', output, cls_index)\n",
    "\n",
    "    # representation of START: hidden vectors weighted by the softmax over the masked start logits\n",
    "    start_p = tf.nn.softmax(start_logits_masked, axis = -1, name = 'softmax_start')\n",
    "    start_feature = tf.einsum('lbh,bl->bh', output, start_p)\n",
    "\n",
    "    # note(zhiliny): no dependency on end_feature so that we can obtain\n",
    "    # one single `cls_logits` for each sample\n",
    "    ans_hidden = tf.layers.dense(\n",
    "        tf.concat([start_feature, cls_feature], -1),\n",
    "        bert_config.hidden_size,\n",
    "        activation = tf.tanh,\n",
    "        kernel_initializer = modeling.create_initializer(bert_config.initializer_range),\n",
    "        name = 'dense_0',\n",
    "    )\n",
    "    ans_feature = tf.layers.dropout(\n",
    "        ans_hidden, bert_config.hidden_dropout_prob, training = is_training\n",
    "    )\n",
    "    # single bias-free unit per sample, with the unit axis squeezed away\n",
    "    cls_logits = tf.squeeze(\n",
    "        tf.layers.dense(\n",
    "            ans_feature,\n",
    "            1,\n",
    "            kernel_initializer = modeling.create_initializer(bert_config.initializer_range),\n",
    "            name = 'dense_1',\n",
    "            use_bias = False,\n",
    "        ),\n",
    "        -1,\n",
    "    )\n",
    "\n",
    "return_dict['cls_logits'] = cls_logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dynamic size of axis 1 of the input-ids placeholder.\n",
    "seq_length = tf.shape(model.X)[1]\n",
    "\n",
    "# Answerability logits from return_dict, and the is_impossible placeholder flattened to rank 1.\n",
    "cls_logits = return_dict['cls_logits']\n",
    "is_impossible = tf.reshape(model.is_impossible, [-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from albert-base-squad/model.ckpt\n"
     ]
    }
   ],
   "source": [
    "# Initialise every variable first, then overwrite the trainable ones from the checkpoint.\n",
    "# The Saver only covers tf.trainable_variables(), so any other variable keeps its initial value.\n",
    "sess = tf.InteractiveSession()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "saver = tf.train.Saver(var_list = tf.trainable_variables())\n",
    "saver.restore(sess, 'albert-base-squad/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import bert_utils as squad_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test minibatch loop: 100%|██████████| 559/559 [02:16<00:00,  4.08it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Run the test features through the graph in mini-batches and collect one\n",
    "# RawResultV2 per feature.\n",
    "# Tensors fetched per batch; this order fixes the unpacking order further down.\n",
    "fetches = [\n",
    "    start_top_log_probs,\n",
    "    start_top_index,\n",
    "    end_top_log_probs,\n",
    "    end_top_index,\n",
    "    cls_logits,\n",
    "]\n",
    "\n",
    "all_results = []\n",
    "pbar = tqdm(\n",
    "    range(0, len(test_features), batch_size), desc = 'test minibatch loop'\n",
    ")\n",
    "for i in pbar:\n",
    "    batch = test_features[i: i + batch_size]\n",
    "    # The feed values are built inline so that the graph tensors named `p_mask` and\n",
    "    # `is_impossible` (defined in earlier cells) are not rebound to Python lists.\n",
    "    # Only the inputs the fetched tensors are fed with are supplied; label fields are not needed.\n",
    "    feed = {\n",
    "        model.X: [b.input_ids for b in batch],\n",
    "        model.segment_ids: [b.segment_ids for b in batch],\n",
    "        model.input_masks: [b.input_mask for b in batch],\n",
    "        model.p_mask: [b.p_mask for b in batch],\n",
    "    }\n",
    "    (\n",
    "        batch_start_log_probs,\n",
    "        batch_start_index,\n",
    "        batch_end_log_probs,\n",
    "        batch_end_index,\n",
    "        batch_cls_logits,\n",
    "    ) = sess.run(fetches, feed_dict = feed)\n",
    "\n",
    "    for no, b in enumerate(batch):\n",
    "        # Convert each row to plain Python floats / ints before storing it.\n",
    "        all_results.append(squad_utils.RawResultV2(\n",
    "            unique_id = b.unique_id,\n",
    "            start_top_log_probs = [float(x) for x in batch_start_log_probs[no].flat],\n",
    "            start_top_index = [int(x) for x in batch_start_index[no].flat],\n",
    "            end_top_log_probs = [float(x) for x in batch_end_log_probs[no].flat],\n",
    "            end_top_index = [int(x) for x in batch_end_index[no].flat],\n",
    "            cls_logits = float(batch_cls_logits[no].flat[0]),\n",
    "        ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Post-processing limits, plus the two dicts handed to accumulate_predictions_v2.\n",
    "# NOTE(review): the dicts start empty and the call's return value is unused, so they are\n",
    "# presumably filled in place -- confirm in bert_utils.\n",
    "max_answer_length = 30\n",
    "n_best_size = 20\n",
    "result_dict, cls_dict = {}, {}\n",
    "\n",
    "squad_utils.accumulate_predictions_v2(\n",
    "    result_dict,\n",
    "    cls_dict,\n",
    "    test_examples,\n",
    "    test_features,\n",
    "    all_results,\n",
    "    n_best_size,\n",
    "    max_answer_length,\n",
    "    start_n_top,\n",
    "    end_n_top,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# The dev-set location can be overridden through the SQUAD_DEV_FILE environment variable;\n",
    "# when it is unset the existing hard-coded location is used.\n",
    "predict_path = os.environ.get('SQUAD_DEV_FILE', '/home/husein/pure-text/ms-dev-2.0.json')\n",
    "\n",
    "# Decode explicitly as UTF-8 instead of relying on the locale's default encoding.\n",
    "with open(predict_path, encoding = 'utf-8') as predict_file:\n",
    "    # only the top-level \"data\" entry of the file is kept\n",
    "    prediction_json = json.load(predict_file)[\"data\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Writing predictions to: predict.json\n",
      "INFO:tensorflow:Writing nbest to: nbest_predictions.json\n"
     ]
    }
   ],
   "source": [
    "# Output file names handed to evaluate_v2 (predictions, n-best list, null log-odds).\n",
    "output_prediction_file = 'predict.json'\n",
    "output_nbest_file = 'nbest_predictions.json'\n",
    "output_null_log_odds_file = 'null_odds.json'\n",
    "\n",
    "# NOTE(review): the return value of evaluate_v2 is discarded, so this cell displays no\n",
    "# metrics -- confirm whether the helper returns scores worth showing.\n",
    "squad_utils.evaluate_v2(\n",
    "  result_dict, cls_dict, prediction_json, test_examples,\n",
    "  test_features, all_results, n_best_size,\n",
    "  max_answer_length, output_prediction_file, output_nbest_file,\n",
    "  output_null_log_odds_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
